import matplotlib.pyplot as plt
import numpy as np


def forward(x, weight=None):
    """Return the linear-model prediction ``weight * x``.

    Args:
        x: Input value (a scalar or anything supporting ``*``).
        weight: Model weight. When omitted (``None``), the module-level
            global ``w`` is used instead. That global is what the sweep loop
            in the ``__main__`` block binds, so ``forward(x)`` keeps its
            original meaning there.

    Returns:
        The prediction ``weight * x``.

    Raises:
        NameError: if ``weight`` is omitted and no global ``w`` exists yet,
            e.g. when the module is imported rather than run as a script.
    """
    # The conditional expression only evaluates the global ``w`` when no
    # explicit weight was given, so passing ``weight`` never touches the global.
    return (w if weight is None else weight) * x


def loss(x, y):
    """Squared error of the model prediction ``forward(x)`` against target ``y``.

    Args:
        x: Input value passed to ``forward``.
        y: Target value.

    Returns:
        ``(forward(x) - y) ** 2``, computed as a product of the residual
        with itself.
    """
    residual = forward(x) - y
    return residual * residual


if __name__ == '__main__':
    # Toy dataset generated by y = 2 * x, so the MSE curve bottoms out at w = 2.
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]

    w_list = []
    mse_list = []
    # np.linspace instead of np.arange with a float step: arange's element
    # count depends on floating-point rounding of (stop - start) / step,
    # whereas linspace fixes it explicitly (39 points: 0.1, 0.2, ..., 3.9).
    # NOTE: the loop variable must stay the module-level name `w`, because
    # forward() reads it as a global when called without an explicit weight.
    for w in np.linspace(0.1, 3.9, 39):
        w_list.append(w)
        print("w=", w)
        loss_val = sum(
            loss(x_val, y_val) for x_val, y_val in zip(x_data, y_data)
        )
        # Average over the actual number of samples rather than a hard-coded 3.
        mse = loss_val / len(x_data)
        print("MSE=", mse)
        mse_list.append(mse)

    plt.figure()
    plt.plot(w_list, mse_list)
    plt.xlabel('w')
    plt.ylabel('MSE')
    plt.show()
